package com.hsshy.beam.common.utils;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.util.ObjectUtil;
import cn.hutool.core.util.StrUtil;
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
import com.baomidou.mybatisplus.core.toolkit.Wrappers;
import com.baomidou.mybatisplus.core.toolkit.support.SFunction;
import com.hsshy.beam.common.annotion.Query;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.beans.PropertyDescriptor;
import java.lang.invoke.CallSite;
import java.lang.invoke.LambdaMetafactory;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import static java.lang.invoke.LambdaMetafactory.FLAG_SERIALIZABLE;
/**
 * Builds MyBatis-Plus {@link LambdaQueryWrapper}s from query objects whose fields are annotated
 * with {@link Query}.
 *
 * @author hs
 * @date 2020-2-21 14:59:48
 */
public class QueryHelp {

    private static final Logger log = LoggerFactory.getLogger(QueryHelp.class);

    private static final String GET_PREFIX = "get";

    /**
     * Builds a {@link LambdaQueryWrapper} from the {@link Query}-annotated fields of {@code query}.
     * <p>
     * Every field of {@code query} (including inherited fields) that carries {@code @Query} and holds
     * a non-empty value (per {@code ObjectUtil.isEmpty}) contributes one condition. The column is
     * identified by an entity getter: the read method of the same-named property on {@code aClass},
     * or {@code get + propName} when {@code propName} is set. When {@code blurry} is set, the value is
     * instead matched with {@code LIKE} against every listed entity property, OR-ed together inside
     * one nested group.
     * <p>
     * Any failure while building (for example a missing entity getter) is logged, and the wrapper
     * built so far is returned; nothing is thrown to the caller.
     *
     * @param query  object whose annotated fields supply the conditions
     * @param aClass entity class whose getters identify the columns
     * @param <Q>    query object type
     * @param <T>    entity type of the returned wrapper
     * @return the (possibly partially) populated wrapper, never {@code null}
     */
    public static <Q, T> LambdaQueryWrapper<T> getLambdaQuery(Q query, Class<?> aClass) {
        LambdaQueryWrapper<T> lw = Wrappers.lambdaQuery();
        try {
            List<Field> fields = getAllFields(query.getClass(), new ArrayList<>());
            for (Field field : fields) {
                Query q = field.getAnnotation(Query.class);
                if (q == null) {
                    continue;
                }
                Object val = readFieldValue(field, query);
                // Empty values (null, "", empty collection/array) contribute no condition; checking
                // first also avoids resolving getters for fields that will be skipped anyway.
                if (ObjectUtil.isEmpty(val)) {
                    continue;
                }
                String blurry = q.blurry();
                if (ObjectUtil.isNotEmpty(blurry)) {
                    // Blurry fields name their target properties explicitly, so the field itself
                    // need not exist on the entity.
                    applyBlurry(lw, blurry, aClass, val);
                    continue;
                }
                String propName = q.propName();
                String methodName;
                if (ObjectUtil.isEmpty(propName)) {
                    // Use the getter of the entity property that shares the query field's name.
                    methodName = new PropertyDescriptor(field.getName(), aClass).getReadMethod().getName();
                } else {
                    methodName = GET_PREFIX + StrUtil.upperFirst(propName);
                }
                SFunction<T, ?> func = getFunction(aClass, methodName);
                applyCondition(lw, q, func, val);
            }
        } catch (Throwable t) {
            log.error(t.getMessage(), t);
        }
        return lw;
    }

    /**
     * Reads {@code field} from {@code target}, lifting access checks for the read and restoring the
     * previous accessibility flag afterwards on every path.
     */
    private static Object readFieldValue(Field field, Object target) throws IllegalAccessException {
        boolean accessible = field.isAccessible();
        field.setAccessible(true);
        try {
            return field.get(target);
        } finally {
            field.setAccessible(accessible);
        }
    }

    /**
     * Adds {@code (p1 LIKE val OR p2 LIKE val ...)} for each comma-separated entity property listed in
     * {@code blurry}. All getters are resolved before the nested group is opened. A bad property name
     * therefore propagates as an exception instead of leaving a half-built group with a null column.
     */
    private static <T> void applyBlurry(LambdaQueryWrapper<T> lw, String blurry, Class<?> aClass,
                                        Object val) throws Throwable {
        List<SFunction<T, ?>> columns = new ArrayList<>();
        for (String prop : blurry.split(",")) {
            // Tolerate "name, code" style lists.
            String name = prop.trim();
            if (!name.isEmpty()) {
                columns.add(getFunction(aClass, GET_PREFIX + StrUtil.upperFirst(name)));
            }
        }
        if (columns.isEmpty()) {
            return;
        }
        lw.and(item -> {
            for (int i = 0; i < columns.size(); i++) {
                if (i > 0) {
                    item.or();
                }
                item.like(columns.get(i), val);
            }
        });
    }

    /**
     * Appends the single condition selected by {@code q.type()} for column {@code func} and value
     * {@code val}. {@code val} is already known to be non-empty. {@code IN} and {@code BETWEEN}
     * expect {@code val} to be a {@link Collection}.
     */
    private static <T> void applyCondition(LambdaQueryWrapper<T> lw, Query q, SFunction<T, ?> func, Object val) {
        switch (q.type()) {
            case EQ:
                lw.eq(func, val);
                break;
            case GE:
                lw.ge(func, val);
                break;
            case LE:
                lw.le(func, val);
                break;
            case LT:
                lw.lt(func, val);
                break;
            case GT:
                lw.gt(func, val);
                break;
            case LIKE:
                lw.like(func, val);
                break;
            case LEFT_LIKE:
                lw.likeLeft(func, val);
                break;
            case RIGHT_LIKE:
                lw.likeRight(func, val);
                break;
            case IN:
                if (CollUtil.isNotEmpty((Collection<?>) val)) {
                    lw.in(func, (Collection<?>) val);
                }
                break;
            case NE:
                lw.ne(func, val);
                break;
            case NOT_NULL:
                lw.isNotNull(func);
                break;
            case IS_NULL:
                lw.isNull(func);
                break;
            case BETWEEN:
                List<Object> between = new ArrayList<>((Collection<?>) val);
                if (between.size() < 2) {
                    // Previously this threw IndexOutOfBoundsException and silently dropped every
                    // remaining condition; skip only the malformed one instead.
                    log.warn("Skipping BETWEEN condition: expected 2 values but got {}", between.size());
                    break;
                }
                lw.between(func, between.get(0), between.get(1));
                break;
            default:
                break;
        }
    }

    /**
     * Builds a serializable {@link SFunction} equivalent to the method reference
     * {@code Entity::methodName}.
     * <p>
     * The lambda is generated with {@code FLAG_SERIALIZABLE} so that its serialized form carries the
     * implementation method name. The getter's own return type is used, not the query field's type.
     * A {@code List<Long>} IN field can therefore target a {@code Long} getter.
     *
     * @param aClass     entity class declaring (or inheriting) the public no-arg getter
     * @param methodName getter name, e.g. {@code getName}
     * @throws Throwable if the getter does not exist or the lambda cannot be linked
     */
    @SuppressWarnings("unchecked")
    private static <T> SFunction<T, ?> getFunction(Class<?> aClass, String methodName) throws Throwable {
        Method getter = aClass.getMethod(methodName);
        Class<?> returnType = getter.getReturnType();
        MethodHandles.Lookup lookup = MethodHandles.lookup();
        CallSite site = LambdaMetafactory.altMetafactory(lookup,
                // SFunction's single abstract method is Function#apply.
                "apply",
                MethodType.methodType(SFunction.class),
                // Erased signature of Function#apply.
                MethodType.methodType(Object.class, Object.class),
                lookup.findVirtual(aClass, methodName, MethodType.methodType(returnType)),
                // Specialized signature; wrap() boxes a primitive getter result (e.g. int -> Integer).
                MethodType.methodType(returnType, aClass).wrap(),
                FLAG_SERIALIZABLE);
        return (SFunction<T, ?>) site.getTarget().invokeExact();
    }

    /**
     * Returns {@code true} if {@code cs} is {@code null}, empty, or contains only whitespace.
     */
    private static boolean isBlank(final CharSequence cs) {
        if (cs == null) {
            return true;
        }
        for (int i = 0; i < cs.length(); i++) {
            if (!Character.isWhitespace(cs.charAt(i))) {
                return false;
            }
        }
        return true;
    }

    /**
     * Appends the declared fields of {@code clazz} and of all its superclasses to {@code fields},
     * subclass fields first.
     *
     * @param clazz  class to start from; {@code null} appends nothing
     * @param fields list to append to
     * @return {@code fields}, the same list instance that was passed in
     */
    public static List<Field> getAllFields(Class clazz, List<Field> fields) {
        for (Class<?> c = clazz; c != null; c = c.getSuperclass()) {
            fields.addAll(Arrays.asList(c.getDeclaredFields()));
        }
        return fields;
    }
}
